This RMarkdown file contains the report of the data analysis done for the project on building and deploying a stroke prediction model in R. It covers data exploration, data visualization, statistical/epidemiological analysis, and the building of the prediction models. The final report was completed on Sat Apr 5 15:12:04 2025.
Data Description:
According to the World Health Organization (WHO), stroke is the 2nd leading cause of death globally, responsible for approximately 11% of total deaths.
This data set is used to predict whether a patient is likely to have a stroke based on input parameters such as gender, age, various diseases, and smoking status. Each row in the data provides relevant information about one patient.
### We use pacman to install and load the required packages
if (!require("pacman")) install.packages("pacman")
pacman::p_load("caret", "data.table", "DescTools", "epitools", "GGally", "ggplot2", "gridExtra", "mlbench", "mltools", "naniar", "parsnip", "pROC", "ranger", "reshape2", "recipes", "rsample", "shiny", "smotefamily", "themis","tidymodels", "tune", "viridis", "workflows", "yardstick", "xgboost")
source("utils.R")
Read the data and check its dimensions:
dat <- as.data.frame(read.csv("healthcare-dataset-stroke-data.csv"))
cat("There are", nrow(dat), "samples and", ncol(dat), "variables in the stroke data.")
## There are 5110 samples and 12 variables in the stroke data.
What are the types of these variables?
sapply(dat, class)
## id gender age hypertension
## "integer" "character" "numeric" "integer"
## heart_disease ever_married work_type Residence_type
## "integer" "character" "character" "character"
## avg_glucose_level bmi smoking_status stroke
## "numeric" "character" "character" "integer"
We can also use the str function, which outputs the column types along with a preview of their values:
str(dat)
## 'data.frame': 5110 obs. of 12 variables:
## $ id : int 9046 51676 31112 60182 1665 56669 53882 10434 27419 60491 ...
## $ gender : chr "Male" "Female" "Male" "Female" ...
## $ age : num 67 61 80 49 79 81 74 69 59 78 ...
## $ hypertension : int 0 0 0 0 1 0 1 0 0 0 ...
## $ heart_disease : int 1 0 1 0 0 0 1 0 0 0 ...
## $ ever_married : chr "Yes" "Yes" "Yes" "Yes" ...
## $ work_type : chr "Private" "Self-employed" "Private" "Private" ...
## $ Residence_type : chr "Urban" "Rural" "Rural" "Urban" ...
## $ avg_glucose_level: num 229 202 106 171 174 ...
## $ bmi : chr "36.6" "N/A" "32.5" "34.4" ...
## $ smoking_status : chr "formerly smoked" "never smoked" "never smoked" "smokes" ...
## $ stroke : int 1 1 1 1 1 1 1 1 1 1 ...
Most of the variables are categorical. We convert them into factors and drop the id column, which carries no predictive information.
dat <- dat[, -1] # drop the id column
dat[, -c(2, 8, 9)] <- lapply(dat[, -c(2, 8, 9)], as.factor) # all but age, avg_glucose_level, bmi
dat$bmi <- as.numeric(dat$bmi) # "N/A" strings become NA
Let us have a glimpse of the data:
head(dat)
## gender age hypertension heart_disease ever_married work_type
## 1 Male 67 0 1 Yes Private
## 2 Female 61 0 0 Yes Self-employed
## 3 Male 80 0 1 Yes Private
## 4 Female 49 0 0 Yes Private
## 5 Female 79 1 0 Yes Self-employed
## 6 Male 81 0 0 Yes Private
## Residence_type avg_glucose_level bmi smoking_status stroke
## 1 Urban 228.69 36.6 formerly smoked 1
## 2 Rural 202.21 NA never smoked 1
## 3 Rural 105.92 32.5 never smoked 1
## 4 Urban 171.23 34.4 smokes 1
## 5 Rural 174.12 24.0 never smoked 1
## 6 Urban 186.21 29.0 formerly smoked 1
To analyze the missing values we use the naniar package. The following plot shows that about 4 percent of bmi values are missing: 201 of the 5110 observations.
vis_miss(dat)
Are these missing values distributed randomly? To check, we look at the distribution of missing values with respect to the variables of interest. Regarding gender, there are more missing values for males than for females; regarding the outcome variable, about 16 percent of bmi values are missing for stroke individuals versus at most 4 percent for non-stroke individuals. Since the missingness depends on observed variables, these values are missing at random rather than missing completely at random.
p1 <- plot_missingness_distribution(dat, "gender")
p2 <- plot_missingness_distribution(dat, "stroke")
grid.arrange(p1, p2, ncol = 2)
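The group-wise missingness percentages behind these plots can be checked directly; the following is a minimal base-R sketch on a toy data frame (plot_missingness_distribution in utils.R is assumed to do the analogous computation on the real data):

```r
# toy data frame mimicking the stroke/bmi structure
toy <- data.frame(
  stroke = c(1, 1, 1, 0, 0, 0, 0, 0),
  bmi    = c(NA, 25, NA, 28, NA, 30, 31, 29)
)
# percentage of missing bmi values within each stroke group
pct_missing <- tapply(toy$bmi, toy$stroke, function(v) 100 * mean(is.na(v)))
pct_missing
```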
The following plots show the distribution of categorical variables.
# Select categorical variables and stroke column
cat_vars <- dat[, -c(2, 8, 9)]
stroke <- dat$stroke
# convert into long format
long_dat <- cat_vars %>% pivot_longer(cols = everything(), names_to = "variable", values_to = "value")
# change the order of categorical variables for plot
long_dat$variable <- factor(long_dat$variable, levels = c("stroke", "gender", "hypertension", "heart_disease", "smoking_status", "ever_married", "work_type", "Residence_type"))
# bar plot of percentages
ggplot(long_dat, aes(x = value, y = after_stat(prop), fill=variable)) +
geom_bar(position = position_dodge(), stat = "prop")+
geom_text(aes(label = round(100 * after_stat(prop), 2)),
position = position_dodge(.9), stat = "prop", vjust = -.2
) +
facet_wrap(~ variable, nrow = 2, ncol = 4, scales = "free_x") +
scale_fill_brewer(palette = "Set2") +
theme_minimal(base_size = 14) +
theme(
plot.title = element_text(hjust = 0.5, size = 16),
axis.text.x = element_text(angle = 45, hjust = 1),
strip.text = element_text(size = 12),
panel.spacing = unit(1.2, "lines"),
legend.position="none"
) +
labs(
x = "",
y = "Proportion of observations",
title = ""
)
We see that in this study there are more females than males, more individuals who have ever been married, and more individuals working in the private sector than in the remaining sectors. There are almost the same number of observations in the rural and urban categories. Below we look at the prevalence of stroke cases across all categorical variables.
cat_vars <- dat[, -c(2, 8, 9, 11)] # remove continuous vars or unwanted cols
cat_vars$stroke <- dat$stroke # keep stroke indicator
# convert to long format
long_dat <- cat_vars %>%
pivot_longer(cols = -stroke, names_to = "variable", values_to = "value")
# stroke prevalence per category
stroke_pct <- long_dat %>%
group_by(variable, value) %>%
summarise(
total = n(),
stroke_cases = sum(stroke == 1),
pct = round(100 * stroke_cases / total, 2),
.groups = "drop"
)
# set the plot order
stroke_pct$variable <- factor(stroke_pct$variable,
levels = c("gender", "hypertension", "heart_disease",
"smoking_status", "ever_married",
"work_type", "Residence_type"))
ggplot(stroke_pct, aes(x = value, y = pct, fill = variable)) +
geom_bar(stat = "identity") +
facet_wrap(~variable, nrow = 2, ncol = 4, scales = "free_x") +
scale_fill_brewer(palette = "Set2") +
theme_minimal(base_size = 14) +
theme(plot.title = element_text(hjust = 0.5, size = 16),
axis.text.x = element_text(angle = 45, hjust = 1),
panel.spacing = unit(1.2, "lines"),
legend.position = "none") +
labs(x = "", y = "Stroke rate (%)", title = "")
We see that females and males are affected roughly equally by this medical condition. There are more stroke cases among individuals with hypertension and heart disease. The stroke rate is higher among individuals who formerly smoked than among the rest; however, many individuals have an unknown smoking status, which may bias this comparison. In general, these variables are related to age: for instance, individuals who formerly smoked or have ever been married tend to be older.
We use Cramér’s V to measure associations between pairs of categorical variables. It ranges from 0 (no association) to 1 (perfect association). The following heatmap shows some association between smoking status and marriage status, and between smoking status and work type. This is likely due to the age of individuals acting as a confounding variable. Most of the associations, however, are weak.
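For a single pair of variables, Cramér’s V can be computed from the chi-squared statistic of their contingency table as V = sqrt(chi2 / (n * (k - 1))), where k is the smaller of the number of rows and columns. A minimal base-R sketch (compute_cramer_v_matrix in utils.R is assumed to apply this pairwise):

```r
# Cramér's V for two categorical vectors
cramers_v <- function(x, y) {
  tbl <- table(x, y)
  chi2 <- suppressWarnings(chisq.test(tbl, correct = FALSE))$statistic
  as.numeric(sqrt(chi2 / (sum(tbl) * (min(dim(tbl)) - 1))))
}
# a perfectly associated pair gives V = 1
cramers_v(rep(c("a", "b"), each = 5), rep(c("x", "y"), each = 5))
```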
cat_vars <- dat[, -c(2, 8, 9, 11)] %>%
mutate(across(everything(), ~ droplevels(as.factor(.))))
cramerV_matrix <- compute_cramer_v_matrix(cat_vars)
cramerV_tidy <- melt(cramerV_matrix, na.rm = TRUE)
cramerV_tidy$Var1 <- factor(cramerV_tidy$Var1, levels = rownames(cramerV_matrix))
cramerV_tidy$Var2 <- factor(cramerV_tidy$Var2, levels = colnames(cramerV_matrix))
ggplot(cramerV_tidy, aes(Var1, Var2, fill = value)) +
geom_tile(color = "white") +
scale_fill_gradient2(
low = "white", mid = "lightblue", high = "steelblue",
midpoint = 0.2, limits = c(0, 1), name = "Cramér’s V"
) +
geom_text(aes(label = round(value, 2)), color = "black", size = 3) +
theme_minimal(base_size = 14) +
theme(
axis.text.x = element_text(angle = 45, hjust = 1),
panel.grid = element_blank()
) +
labs(title = "", x = "", y = "" )
What about continuous variables: age, average glucose level, and body mass index?
The majority of individuals have an average glucose level below 150, but there is a small group of individuals whose average glucose level is concentrated around ~210.
p1 <- ggplot(dat, aes(x=avg_glucose_level)) +
geom_histogram(fill="#92C5DE")+
labs(x = "Average glucose level", y = "") +
theme_minimal(base_size = 12)+
theme(legend.position="none")
p2 <- ggplot(dat, aes(x=bmi)) +
geom_histogram(fill="#E69F00")+
labs(x = "Body mass index", y = "") +
theme_minimal(base_size = 12)+
theme(legend.position="none")
p3 <- ggplot(dat, aes(x=age)) +
geom_histogram(fill="#D6604D")+
labs(x = "Age", y = "") +
theme_minimal(base_size = 12)+
theme(legend.position="none")
grid.arrange(p1, p2, p3, ncol=3)
The distribution of average glucose level among individuals who experienced a stroke is visibly multimodal, with a secondary peak above 200. The median average glucose level among stroke cases is around 105. While individuals with average glucose levels above ~200 mg/dL constitute a small portion of observations, they account for approximately 25% of all stroke cases, as the following plot shows.
p1 <- ggplot(dat, aes(x = stroke, y = avg_glucose_level, color=stroke)) +
geom_boxplot(outlier.shape = NA) +
geom_jitter(width = 0.2, alpha = 0.6, size = 1.5)+
scale_fill_brewer(palette = "Set3") +
scale_color_brewer(palette = "Set2") +
labs(x = "Stroke", y = "Average glucose level", color = "Stroke") +
theme_minimal(base_size = 14) +
theme(legend.position = "none")
agl_medians <- dat %>% group_by(stroke) %>% summarise(agl_median = median(avg_glucose_level))
p2 <- ggplot(dat, aes(x=avg_glucose_level, fill=stroke)) + geom_density(alpha=0.4) +
geom_vline(data = agl_medians, aes(xintercept=agl_median, color=stroke), linetype="dashed")+
labs(x = "Average glucose level", y = "Density", color = "Stroke") +
scale_fill_brewer(palette = "Set3") +
scale_color_brewer(palette = "Set2") +
theme_minimal(base_size = 14)
grid.arrange(p1, p2, ncol=2)
The distribution of body mass index for individuals with stroke is concentrated around the mean (~27). For the bmi variable there is much overlap between the stroke and non-stroke cases.
p1 <- ggplot(dat, aes(x = stroke, y = bmi, color=stroke)) +
geom_boxplot(outlier.shape = NA) +
geom_jitter(width = 0.2, alpha = 0.6, size = 1.5)+
scale_fill_brewer(palette = "Set3") +
scale_color_brewer(palette = "Set2") +
labs(x = "Stroke", y = "Body mass index", color = "Stroke") +
theme_minimal(base_size = 14) +
theme(legend.position = "none")
bmi_means <- dat %>% filter(!is.na(bmi)) %>% group_by(stroke) %>% summarise(bmi_mean = mean(bmi))
p2 <- ggplot(dat, aes(x=bmi, fill=stroke)) + geom_density(alpha=0.4)+
geom_vline(data = bmi_means, aes(xintercept=bmi_mean, color=stroke), linetype="dashed")+
labs(x = "Body mass index", y = "Density", color="stroke") +
scale_fill_brewer(palette = "Set3") +
scale_color_brewer(palette = "Set2") +
theme_minimal(base_size = 14)
grid.arrange(p1, p2, ncol=2)
The following plots show that age is clearly a risk factor for stroke: 75% of individuals with stroke are older than ~60, whereas 75% of individuals without stroke are younger than ~60. For stroke cases the age distribution is left-skewed, with half of the strokes occurring above the age of 70. There are no stroke cases between the ages of 20 and 30 in the data, but there are 2 children with stroke; these paediatric cases can be considered special cases.
p1 <- ggplot(dat, aes(x = stroke, y = age, color=stroke)) +
geom_boxplot(outlier.shape = NA) +
geom_jitter(width = 0.2, alpha = 0.6, size = 1.5)+
scale_fill_brewer(palette = "Set3") +
scale_color_brewer(palette = "Set2") +
labs(x = "Stroke", y = "Age", color = "Stroke") +
theme_minimal(base_size = 14) +
theme(legend.position = "none")
age_means <- dat %>% group_by(stroke) %>% summarise(age_mean = mean(age))
p2 <- ggplot(dat, aes(x=age, fill=stroke)) + geom_density(alpha=0.4)+
geom_vline(data = age_means, aes(xintercept=age_mean, color=stroke), linetype="dashed")+
labs(x = "Age", y = "Density", color = "Stroke") +
scale_fill_brewer(palette = "Set3") +
scale_color_brewer(palette = "Set2") +
theme_minimal(base_size = 14) +
theme(legend.position = "none")
grid.arrange(p1, p2, ncol=2)
Next, we examine the relationship between average glucose level and heart disease.
p1 <- ggplot(dat, aes(x = heart_disease, y = avg_glucose_level, color=heart_disease)) +
geom_boxplot(outlier.shape = NA) +
geom_jitter(width = 0.2, alpha = 0.6, size = 1.5)+
scale_fill_brewer(palette = "Set3") +
scale_color_brewer(palette = "Set2") +
labs(x = "Heart disease", y = "Average glucose level", color = "heart_disease") +
theme_minimal(base_size = 14) +
theme(legend.position = "none")
agl_medians <- dat %>% group_by(heart_disease) %>% summarise(agl_median = median(avg_glucose_level))
p2 <- ggplot(dat, aes(x=avg_glucose_level, fill=heart_disease)) + geom_density(alpha=0.4) +
geom_vline(data = agl_medians, aes(xintercept=agl_median, color=heart_disease), linetype="dashed")+
labs(x = "Average glucose level", y = "Density", color = "heart_disease") +
scale_fill_brewer(palette = "Set3") +
scale_color_brewer(palette = "Set2") +
theme_minimal(base_size = 14)
grid.arrange(p1, p2, ncol=2)
Next, we analyze pairwise relationships between these variables. For stroke cases there is a negative correlation between age and body mass index, whereas this association is positive for non-stroke cases. The association between average glucose level and body mass index is also stronger for stroke cases than for non-stroke cases.
ggpairs(dat, columns=c(2, 8, 9), aes(color=stroke, alpha=0.3),
lower=list(continuous="smooth"), diag=list(continuous="densityDiag"))+
scale_fill_brewer(palette = "Set3") +
scale_color_brewer(palette = "Set2")
75% of individuals with hypertension are above the age of ~52, and half are above ~62. Stroke appears later in life regardless of hypertension; however, hypertensive individuals who experienced a stroke are generally older.
p1 <- ggplot(dat, aes(x = hypertension, y = age, fill = hypertension)) +
geom_boxplot(outlier.shape = NA, position = position_dodge(0.8)) +
geom_jitter(aes(color=hypertension), width=0.2, alpha = 0.5, size = 1.2) +
scale_fill_brewer(palette = "Set2") +
scale_color_brewer(palette = "Set2") +
labs(x = "Hypertension", y = "Age", title = "Age distribution by hypertension") +
theme_minimal(base_size = 14)+
theme(plot.title = element_text(hjust = 0.5, size = 16))
p2 <- ggplot(dat, aes(x = hypertension, y = age, fill = stroke)) +
geom_boxplot(outlier.shape = NA, position = position_dodge(0.8)) +
geom_jitter(aes(color = stroke),
position = position_jitterdodge(jitter.width = 0.2, dodge.width = 0.8),
alpha = 0.5, size = 1.2) +
scale_fill_brewer(palette = "Set3") +
scale_color_brewer(palette = "Set2") +
labs(x = "Hypertension", y = "Age", fill = "Stroke", color = "Stroke",
title = "Age distribution by hypertension and stroke") +
theme_minimal(base_size = 14)+
theme(plot.title = element_text(hjust = 0.5, size = 16))
grid.arrange(p1, p2, ncol=2)
The median age of individuals with heart disease is higher than that of individuals with hypertension: above 70. Similarly, stroke appears later in life regardless of heart disease; however, there is less variability in the age distribution of individuals with both heart disease and stroke: they are predominantly older.
p1 <- ggplot(dat, aes(x = heart_disease, y = age, fill = heart_disease)) +
geom_boxplot(outlier.shape = NA, position = position_dodge(0.8)) +
geom_jitter(aes(color=heart_disease), width=0.2, alpha = 0.5, size = 1.2) +
scale_fill_brewer(palette = "Set2") +
scale_color_brewer(palette = "Set2") +
labs(x = "Heart disease", y = "Age", title = "Age distribution by heart disease") +
theme_minimal(base_size = 14)+
theme(plot.title = element_text(hjust = 0.5, size = 16))
p2 <- ggplot(dat, aes(x = heart_disease, y = age, fill = stroke)) +
geom_boxplot(outlier.shape = NA, position = position_dodge(0.8)) +
geom_jitter(aes(color = stroke),
position = position_jitterdodge(jitter.width = 0.2, dodge.width = 0.8),
alpha = 0.5, size = 1.2) +
scale_fill_brewer(palette = "Set3") +
scale_color_brewer(palette = "Set2") +
labs(
x = "Heart disease",
y = "Age",
fill = "Stroke",
color = "Stroke",
title = "Age distribution by heart disease and stroke"
) +
theme_minimal(base_size = 14)+
theme(
plot.title = element_text(hjust = 0.5, size = 16)
)
grid.arrange(p1, p2, ncol=2)
The following plot shows the age distribution across combinations of hypertension and heart disease. For instance, 1_1 indicates individuals who had both hypertension and heart disease. We separate the boxplots by stroke status.
stroke_hypert_heartd <- dat %>% mutate(hyp_hd = interaction(hypertension, heart_disease, sep = "_"))
ggplot(stroke_hypert_heartd, aes(x = hyp_hd, y = age, fill = stroke)) +
geom_boxplot(position = position_dodge(width = 0.75), outlier.shape = NA) +
geom_jitter(aes(color = stroke),
position = position_jitterdodge(jitter.width = 0.2, dodge.width = 0.75),
alpha = 0.4, size = 1.2) +
scale_fill_brewer(palette = "Set3") +
scale_color_brewer(palette = "Set2") +
labs(
x = "Hypertension_heart disease",
y = "Age",
title = "Age distribution by hypertension & heart disease, split by stroke"
) +
theme_minimal(base_size = 14) +
theme(plot.title = element_text(hjust = 0.5, size = 16))
In this section, we perform a statistical analysis of the associations between variables.
To compute odds ratios we use the epitools package. The outcome of interest is stroke, and the exposures are hypertension, heart disease, average glucose level, and age.
The following odds ratio computation shows that the odds of stroke in hypertensive individuals are ~3.7 times those in non-hypertensive individuals.
tbl_hypertension <- table(dat$stroke, dat$hypertension)
oddsratio(tbl_hypertension)
## $data
##
## 0 1 Total
## 0 4429 432 4861
## 1 183 66 249
## Total 4612 498 5110
##
## $measure
## odds ratio with 95% C.I.
## estimate lower upper
## 0 1.000000 NA NA
## 1 3.701246 2.729655 4.964923
##
## $p.value
## two-sided
## midp.exact fisher.exact chi.square
## 0 NA NA NA
## 1 5.77316e-15 4.549182e-15 6.068123e-20
##
## $correction
## [1] FALSE
##
## attr(,"method")
## [1] "median-unbiased estimate & mid-p exact CI"
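As a sanity check, the unadjusted odds ratio can be recovered directly from the 2x2 counts in the table above via the simple cross-product estimate (epitools reports the slightly different median-unbiased estimate):

```r
# counts from the table above: rows = stroke (0/1), columns = hypertension (0/1)
n00 <- 4429  # no stroke, no hypertension
n01 <- 432   # no stroke, hypertension
n10 <- 183   # stroke, no hypertension
n11 <- 66    # stroke, hypertension
or_hat <- (n00 * n11) / (n01 * n10)
round(or_hat, 2)  # ~3.70
```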
Similarly, the odds of stroke in individuals with heart disease are ~4.71 times those in individuals without heart disease.
tbl_heart_disease <- table(dat$stroke, dat$heart_disease)
oddsratio(tbl_heart_disease)
## $data
##
## 0 1 Total
## 0 4632 229 4861
## 1 202 47 249
## Total 4834 276 5110
##
## $measure
## odds ratio with 95% C.I.
## estimate lower upper
## 0 1.000000 NA NA
## 1 4.714026 3.309098 6.601552
##
## $p.value
## two-sided
## midp.exact fisher.exact chi.square
## 0 NA NA NA
## 1 8.881784e-15 7.283093e-15 5.20011e-22
##
## $correction
## [1] FALSE
##
## attr(,"method")
## [1] "median-unbiased estimate & mid-p exact CI"
To perform an odds ratio analysis with average glucose level, we bin it into two groups at the level of 150 mg/dL. Individuals with average glucose level >150 have ~3.7 times higher odds of stroke than those with glucose ≤150.
stroke_agl <- dat %>% mutate(agl_group = cut(avg_glucose_level, breaks = c(0, 150, max(dat$avg_glucose_level)))) %>% select(c(stroke, agl_group))
tbl_agl <- table(stroke_agl$stroke, stroke_agl$agl_group)
oddsratio(tbl_agl)
## $data
##
## (0,150] (150,272] Total
## 0 4221 640 4861
## 1 159 90 249
## Total 4380 730 5110
##
## $measure
## odds ratio with 95% C.I.
## estimate lower upper
## 0 1.000000 NA NA
## 1 3.734437 2.836751 4.889724
##
## $p.value
## two-sided
## midp.exact fisher.exact chi.square
## 0 NA NA NA
## 1 0 6.607959e-19 5.174082e-24
##
## $correction
## [1] FALSE
##
## attr(,"method")
## [1] "median-unbiased estimate & mid-p exact CI"
The odds of stroke in individuals aged above 60 are ~8.13 times those in individuals aged below 60, so age is a very important risk factor. This result is statistically highly significant: the confidence interval does not contain 1, as with the two results reported above.
stroke_age <- dat %>% mutate(age_group = cut(age, breaks = c(0, 60, max(dat$age)), labels=c("0", "1"))) %>% select(c(stroke, age_group))
tbl_age <- table(stroke_age$stroke, stroke_age$age_group)
oddsratio(tbl_age)
## $data
##
## 0 1 Total
## 0 3734 1127 4861
## 1 72 177 249
## Total 3806 1304 5110
##
## $measure
## odds ratio with 95% C.I.
## estimate lower upper
## 0 1.000000 NA NA
## 1 8.130197 6.159177 10.83745
##
## $p.value
## two-sided
## midp.exact fisher.exact chi.square
## 0 NA NA NA
## 1 0 3.741383e-54 3.822487e-64
##
## $correction
## [1] FALSE
##
## attr(,"method")
## [1] "median-unbiased estimate & mid-p exact CI"
What about gender? There is no significant association between gender and the occurrence of stroke.
tbl_gender <- table(dat$stroke, dat$gender)
oddsratio(tbl_gender)
## $data
##
## Female Male Other Total
## 0 2853 2007 1 4861
## 1 141 108 0 249
## Total 2994 2115 1 5110
##
## $measure
## odds ratio with 95% C.I.
## estimate lower upper
## 0 1.000000 NA NA
## 1 1.089201 0.840725 1.407353
##
## $p.value
## two-sided
## midp.exact fisher.exact chi.square
## 0 NA NA NA
## 1 0.5160089 0.5746131 0.7895491
##
## $correction
## [1] FALSE
##
## attr(,"method")
## [1] "median-unbiased estimate & mid-p exact CI"
What percentage of the stroke cases among hypertensive individuals can be attributed to hypertension? To answer this question we compute the attributable risk percent: 70% of strokes in hypertensive individuals could be attributed to hypertension.
total_hypert <- sum(dat$hypertension==1)
total_no_hypert <- sum(dat$hypertension==0)
stroke_no_hypertension <- sum((dat$hypertension==0) & (dat$stroke==1))
stroke_with_hypertension <- sum((dat$hypertension==1) & (dat$stroke==1))
par <- round(((stroke_with_hypertension/total_hypert)-(stroke_no_hypertension/total_no_hypert))*100 / (stroke_with_hypertension/total_hypert))
cat(paste0("Attributable risk percent due to hypertension is ", par, "%."))
## Attributable risk percent due to hypertension is 70%.
What percentage of the stroke cases in individuals above the age of 60 can be attributed to their age? 86% of strokes in individuals above the age of 60 could be attributed to their age.
total_above_60 <- sum(stroke_age$age_group==1)
total_below_60 <- sum(stroke_age$age_group==0)
stroke_above_60 <- sum((stroke_age$age_group==1) & (dat$stroke==1))
stroke_below_60 <- sum((stroke_age$age_group==0) & (dat$stroke==1))
par <- round(((stroke_above_60/total_above_60)-(stroke_below_60/total_below_60))*100 / (stroke_above_60/total_above_60))
cat(paste0("Attributable risk percent due to being above 60 is ", par, "%."))
## Attributable risk percent due to being above 60 is 86%.
What percentage of the stroke cases in the data can be attributed to hypertension? To answer this question we compute the population attributable risk percent. The computation below shows that only 19% of all stroke cases could be attributed to hypertension; similarly, only 14% of all stroke cases could be attributed to heart disease.
n <- nrow(dat)
total_stroke <- sum(dat$stroke==1)
no_hypert <- sum(dat$hypertension==0)
stroke_no_hypertension <- sum((dat$hypertension==0) & (dat$stroke==1))
no_heart_disease <- sum(dat$heart_disease==0)
stroke_no_heart_disease <- sum((dat$heart_disease==0) & (dat$stroke==1))
par <- round(((total_stroke/n)-(stroke_no_hypertension/no_hypert))*100/(total_stroke/n))
par2 <- round(((total_stroke/n)-(stroke_no_heart_disease/no_heart_disease))*100/(total_stroke/n))
cat(paste0("Population attributable risk percent due to hypertension is ", par, "% and due to heart disease is ", par2, "%."))
## Population attributable risk percent due to hypertension is 19% and due to heart disease is 14%.
In the pairwise association plot above we saw that for individuals with stroke, body mass index tends to decrease with age, whereas in general people tend to have higher body mass index with age. Let us look at this relationship more closely:
p1 <- ggplot(dat, aes(x = age, y = bmi)) +
geom_point(alpha = 0.3, aes(color="#E69F00")) +
geom_smooth(method = "lm", se = TRUE, color='#E69F00') +
scale_fill_brewer(palette = "Set3") +
scale_color_brewer(palette = "Set2") +
labs(color = "stroke", title = "Age vs bmi")+
theme(plot.title = element_text(hjust = 0.5, size = 16), legend.position="none")
p2 <- ggplot(dat, aes(x = age, y = bmi, color = stroke, shape=stroke)) +
geom_point(alpha = 0.3) +
geom_smooth(method = "lm", se = TRUE) +
scale_fill_brewer(palette = "Set3") +
scale_color_brewer(palette = "Set2") +
labs(color = "stroke", title = "Age vs bmi by stroke status")+
theme(plot.title = element_text(hjust = 0.5, size = 16))
grid.arrange(p1, p2, ncol=2)
Is the interaction of age and stroke statistically significant for bmi?
We see that the body mass index increases on average by ~0.12 units with each additional year for non-stroke cases, whereas for stroke cases the slope is 0.1228 - 0.2757 ≈ -0.153, i.e. bmi decreases by about 0.15 units per additional year. Note, however, that age and stroke explain very little of the variance in body mass index: around 12 percent.
lm_mod <- lm(bmi ~ age * stroke, data = dat)
summary(lm_mod)
##
## Call:
## lm(formula = bmi ~ age * stroke, data = dat)
##
## Residuals:
## Min 1Q Median 3Q Max
## -22.097 -4.998 -1.412 3.564 71.818
##
## Coefficients:
## Estimate Std. Error t value Pr(>|t|)
## (Intercept) 23.693664 0.228446 103.717 < 2e-16 ***
## age 0.122829 0.004827 25.446 < 2e-16 ***
## stroke1 17.125887 2.844665 6.020 1.87e-09 ***
## age:stroke1 -0.275655 0.041475 -6.646 3.33e-11 ***
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## Residual standard error: 7.368 on 4905 degrees of freedom
## (201 observations deleted due to missingness)
## Multiple R-squared: 0.1204, Adjusted R-squared: 0.1198
## F-statistic: 223.8 on 3 and 4905 DF, p-value: < 2.2e-16
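The group-specific age slopes follow directly from the printed coefficients: the non-stroke slope is the age coefficient, and the stroke slope adds the interaction term.

```r
# coefficients as printed in the model summary above
b_age         <- 0.122829   # age slope for non-stroke cases
b_interaction <- -0.275655  # age:stroke1 interaction term
slope_stroke  <- b_age + b_interaction
round(slope_stroke, 3)  # ~ -0.153
```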
Logistic regression models the log odds of stroke as a linear combination of variables. It allows us to see the effect of each covariate on the odds of stroke while adjusting for all the other covariates. The result below shows that the age, average glucose level, and hypertension are significant covariates for stroke, but heart disease is not, after adjusting for other variables. In other words, older individuals, individuals with hypertension, and individuals with higher glucose levels are at increased risk of stroke.
lg_model <- glm(stroke ~ age + gender + hypertension + heart_disease + ever_married + work_type + Residence_type + avg_glucose_level + smoking_status, data = dat, family = "binomial")
summary(lg_model)
##
## Call:
## glm(formula = stroke ~ age + gender + hypertension + heart_disease +
## ever_married + work_type + Residence_type + avg_glucose_level +
## smoking_status, family = "binomial", data = dat)
##
## Coefficients:
## Estimate Std. Error z value Pr(>|z|)
## (Intercept) -6.759e+00 7.527e-01 -8.980 < 2e-16 ***
## age 7.464e-02 5.729e-03 13.029 < 2e-16 ***
## genderMale 1.244e-02 1.418e-01 0.088 0.930114
## genderOther -1.054e+01 1.455e+03 -0.007 0.994221
## hypertension1 4.050e-01 1.644e-01 2.463 0.013779 *
## heart_disease1 2.791e-01 1.911e-01 1.461 0.144040
## ever_marriedYes -1.833e-01 2.254e-01 -0.814 0.415902
## work_typeGovt_job -9.298e-01 8.210e-01 -1.133 0.257393
## work_typeNever_worked -1.032e+01 3.095e+02 -0.033 0.973387
## work_typePrivate -7.877e-01 8.051e-01 -0.978 0.327907
## work_typeSelf-employed -1.165e+00 8.264e-01 -1.410 0.158683
## Residence_typeUrban 8.334e-02 1.383e-01 0.602 0.546897
## avg_glucose_level 4.053e-03 1.174e-03 3.451 0.000558 ***
## smoking_statusnever smoked -2.069e-01 1.759e-01 -1.176 0.239524
## smoking_statussmokes 1.121e-01 2.153e-01 0.521 0.602553
## smoking_statusUnknown -7.298e-02 2.084e-01 -0.350 0.726145
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## (Dispersion parameter for binomial family taken to be 1)
##
## Null deviance: 1990.4 on 5109 degrees of freedom
## Residual deviance: 1581.2 on 5094 degrees of freedom
## AIC: 1613.2
##
## Number of Fisher Scoring iterations: 14
To interpret the effect of age, hypertension, and average glucose level on stroke, we need to exponentiate the coefficients. Each additional year of age is associated with a ~7.7% increase in the odds of stroke, each additional unit of average glucose level with a ~0.4% increase, and people with hypertension have ~1.5 times the odds of stroke compared to those without hypertension.
coeffs <- lg_model$coefficients
coeffs_sign <- coeffs[names(coeffs) %in% c("age", "hypertension1", "avg_glucose_level")]
exp(coeffs_sign)
## age hypertension1 avg_glucose_level
## 1.077495 1.499276 1.004061
This is a problem with high class imbalance: only around 5 percent of observations are stroke cases, so it is important to choose relevant evaluation measures. Without any learning, and without even considering the input variables, predicting no stroke for everyone would achieve very high accuracy. We therefore concentrate on sensitivity and specificity first. Sensitivity measures the proportion of individuals correctly identified as stroke cases among those who actually had a stroke, and specificity measures the proportion of individuals correctly identified as healthy among those who are actually healthy. Both measures take values between 0 and 1. If all individuals are predicted as non-stroke, specificity is 1 since there are no false positives, but sensitivity is 0 since no stroke case, i.e. no true positive, is identified.
Sensitivity is also known as recall. Precision is the fraction of true positives over the sum of true and false positives. The F1-measure, the harmonic mean of precision and recall, balances the two. Notice that if specificity is 1 and sensitivity is 0, precision is undefined, since neither true positives nor false positives are predicted; consequently the F1-measure is undefined as well.
Next, we consider the area under the ROC curve. This curve plots sensitivity against the false positive rate (1 - specificity) at every cut-off threshold for the predicted probabilities, and the larger the area under the curve, the better the predictive performance of the algorithm.
Finally, we consider the Matthews correlation coefficient (MCC). Its value is computed from the entire confusion matrix, i.e., it takes into account true positives, false positives, true negatives, and false negatives, and it can be considered the most suitable measure in the class-imbalance setting.
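All of these metrics can be computed from the four cells of the confusion matrix; a minimal base-R sketch (an illustrative helper, not part of the report's utils.R):

```r
# sensitivity, specificity, precision, F1 and MCC from confusion-matrix counts
conf_metrics <- function(tp, fp, tn, fn) {
  sens <- tp / (tp + fn)
  spec <- tn / (tn + fp)
  prec <- tp / (tp + fp)
  f1   <- 2 * prec * sens / (prec + sens)
  mcc  <- (tp * tn - fp * fn) /
    sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
  list(sensitivity = sens, specificity = spec,
       precision = prec, f1 = f1, mcc = mcc)
}
conf_metrics(tp = 3, fp = 2, tn = 90, fn = 5)
```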
Since we are dealing with a class-imbalance problem, we can choose cut-off thresholds below the standard 0.5 level, which yields different values of sensitivity, specificity, the F1-measure, and the Matthews coefficient. Our strategy is based on Youden’s index, i.e., the threshold that maximizes sensitivity + specificity - 1.
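This threshold selection can be sketched as a simple grid search over the observed predicted probabilities (an illustrative helper; the report's actual threshold computation may differ):

```r
# choose the cut-off that maximizes Youden's index J = sensitivity + specificity - 1
youden_threshold <- function(probs, labels) {
  thresholds <- sort(unique(probs))
  j <- vapply(thresholds, function(t) {
    pred <- as.integer(probs >= t)
    sens <- sum(pred == 1 & labels == 1) / sum(labels == 1)
    spec <- sum(pred == 0 & labels == 0) / sum(labels == 0)
    sens + spec - 1
  }, numeric(1))
  thresholds[which.max(j)]
}
youden_threshold(c(0.9, 0.6, 0.4, 0.1), c(1, 1, 0, 0))  # perfectly separable: 0.6
```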
We can impute the missing values using the k-nearest-neighbour method, which assumes that individuals with similar characteristics have similar body mass index. This assumption is questionable from a precision-medicine point of view, since even individuals with very similar characteristics can differ biologically; alternatively, we could remove these individuals from the data set.
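The idea behind k-nearest-neighbour imputation can be sketched in base R: each missing bmi value is replaced by the mean bmi of the k individuals closest on the complete variables. This toy version uses age alone as the distance variable (the real imputation uses all predictors):

```r
# toy kNN imputation: fill missing bmi with the mean bmi of the k nearest ages
impute_bmi_knn <- function(age, bmi, k = 2) {
  obs <- which(!is.na(bmi))
  for (i in which(is.na(bmi))) {
    nn <- obs[order(abs(age[obs] - age[i]))][seq_len(k)]
    bmi[i] <- mean(bmi[nn])
  }
  bmi
}
impute_bmi_knn(age = c(20, 21, 60, 61, 40), bmi = c(22, 24, 30, 32, NA))
```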
We first split the data into 75%/25% training and test sets using the rsample package of the tidymodels framework, stratifying on the stroke variable so that the proportion of stroke cases is similar across the splits. In what follows, we use set.seed to ensure reproducibility.
set.seed(123)
dat_split <- initial_split(dat, prop = 3/4, strata = stroke)
dat_train <- training(dat_split)
dat_test <- testing(dat_split)
dat_cv <- vfold_cv(dat_train)
We use logistic regression, as well as models of higher capacity capable of learning non-linear decision boundaries: random forest and extreme gradient boosting, an improved version of the gradient boosting method. The latter models are particularly well suited to data with many categorical variables.
We train logistic regression with and without SMOTE. SMOTE generates synthetic samples for the rare class based on the k-nearest neighbour method. The proportion of synthetic examples is controlled through the over_ratio argument of step_smote in the recipe. The number of synthetic samples generated affects sensitivity and specificity: more synthetic samples improve sensitivity but degrade specificity.
We use the parsnip package of tidymodels to train the logistic regression and random forest models, and the xgboost package for extreme gradient boosting. Hyperparameters are tuned with a grid search.
ground_truth <- factor(dat_test$stroke, levels = c(1, 0))
dat_recipe <- recipe(stroke ~ ., data = dat_train) %>% step_impute_knn(all_predictors())
lr_model <- logistic_reg() %>% set_engine("glm") %>% set_mode("classification")
lr_workflow <- workflow() %>% add_model(lr_model) %>% add_recipe(dat_recipe)
lr_fit <- fit(lr_workflow, data = dat_train)
lr_fit_extracted <- extract_fit_engine(lr_fit)
lr_probs1 <- predict(lr_fit, dat_test, type = "prob")
lr_result1 <- evaluate_model_fit(ground_truth, lr_probs1$.pred_1)
## Setting levels: control = 1, case = 0
## Setting direction: controls > cases
and with SMOTE:
dat_smote_recipe <- recipe(stroke ~ ., data = dat_train) %>%
step_impute_knn(all_predictors()) %>%
step_dummy(all_nominal_predictors()) %>%
step_smote(stroke, over_ratio = 0.5)
lr_model2 <- logistic_reg() %>% set_engine("glm") %>% set_mode("classification")
lr_workflow2 <- workflow() %>% add_model(lr_model2) %>% add_recipe(dat_smote_recipe) ## use smote data recipe
lr_fit2 <- fit(lr_workflow2, data = dat_train)
lr_fit_extracted2 <- extract_fit_engine(lr_fit2)
lr_probs2 <- predict(lr_fit2, dat_test, type = "prob")
lr_result2 <- evaluate_model_fit(ground_truth, lr_probs2$.pred_1)
## Setting levels: control = 1, case = 0
## Setting direction: controls > cases
rf_model <- rand_forest() %>% set_args(mtry = tune()) %>%
set_engine("ranger", importance = "impurity") %>%
set_mode("classification")
rf_workflow <- workflow() %>%
add_recipe(dat_smote_recipe) %>%
add_model(rf_model)
rf_grid <- expand.grid(mtry = c(2, 3, 4, 5))
rf_tune_results <- rf_workflow %>% tune_grid(resamples = dat_cv, grid = rf_grid, metrics = metric_set(roc_auc))
rf_tune_results %>% collect_metrics()
## # A tibble: 4 × 7
## mtry .metric .estimator mean n std_err .config
## <dbl> <chr> <chr> <dbl> <int> <dbl> <chr>
## 1 2 roc_auc binary 0.819 10 0.00867 Preprocessor1_Model1
## 2 3 roc_auc binary 0.817 10 0.0109 Preprocessor1_Model2
## 3 4 roc_auc binary 0.809 10 0.0116 Preprocessor1_Model3
## 4 5 roc_auc binary 0.808 10 0.0129 Preprocessor1_Model4
param_final <- rf_tune_results %>% select_best(metric = "roc_auc")
rf_workflow <- rf_workflow %>% finalize_workflow(param_final)
rf_fit <- rf_workflow %>% last_fit(dat_split)
#test_performance <- rf_fit %>% collect_metrics()
rf_probs <- as.data.frame(rf_fit$.predictions)$.pred_1
rf_result <- evaluate_model_fit(ground_truth, rf_probs)
## Setting levels: control = 1, case = 0
## Setting direction: controls > cases
dat_smote_recipe <- dat_smote_recipe %>% prep()
X_train <- dat_train %>% select(-stroke)
X_test <- dat_test %>% select(-stroke)
X_train <- bake(dat_smote_recipe, X_train)
X_test <- bake(dat_smote_recipe, X_test)
y_train <- as.numeric(dat_train$stroke)-1
y_test <- as.numeric(dat_test$stroke)-1
X_train_matrix <- as.matrix(X_train)
X_test_matrix <- as.matrix(X_test)
dtrain <- xgb.DMatrix(data = X_train_matrix, label = y_train)
dtest <- xgb.DMatrix(data = X_test_matrix, label = y_test)
params_list <- expand.grid(
eta = c(0.01, 0.1, 0.3),
max_depth = c(3, 6, 9, 12),
subsample = c(0.3, 0.5, 0.7, 0.9),
colsample_bytree = c(0.2, 0.5, 0.8)
)
cv_results <- xgb_tune(params_list, dtrain, 100)
cv_summary <- rbindlist(cv_results)
best_params <- cv_summary[which.max(auc)]
### we train xgb with the best hyperparameter configuration on the entire training set.
xgb_final_model <- xgb.train(
params = as.list(best_params),
data = dtrain,
nrounds = best_params$best_iteration,
watchlist = list(train = dtrain),
verbose = 0
)
## [15:15:46] WARNING: src/learner.cc:767:
## Parameters: { "auc", "best_iteration" } are not used.
xgb_probs <- predict(xgb_final_model, dtest)
xgb_result <- evaluate_model_fit(ground_truth, xgb_probs)
## Setting levels: control = 1, case = 0
## Setting direction: controls > cases
According to the Matthews correlation coefficient, all methods perform weakly, although better than random. Their performances are comparable, but logistic regression with SMOTE stands out in terms of AUC, F1-measure, Matthews coefficient and specificity. Overall, generating synthetic examples improved the predictive performance of the logistic regression model at the cost of sensitivity.
results_df <- tribble(
~Model, ~Sensitivity, ~Specificity, ~F1, ~AUC, ~Matthews,
"Logistic Reg", lr_result1$sensi, lr_result1$speci, lr_result1$f1, lr_result1$aucroc, lr_result1$matcc,
"Logistic Reg Smote", lr_result2$sensi, lr_result2$speci, lr_result2$f1, lr_result2$aucroc, lr_result2$matcc,
"Random Forest Smote", rf_result$sensi, rf_result$speci, rf_result$f1, rf_result$aucroc, rf_result$matcc,
"XGBoost Smote", xgb_result$sensi, xgb_result$speci, xgb_result$f1, xgb_result$aucroc, xgb_result$matcc
)
results_long <- results_df %>% pivot_longer(cols = -Model, names_to = "Metric", values_to = "Value")
ggplot(results_long, aes(x = Metric, y = Value, fill = Model)) +
geom_bar(stat = "identity", position = position_dodge(), width = 0.7) +
scale_fill_brewer(palette = "Set2") +
scale_color_brewer(palette = "Set2") +
labs(
title = "Performance comparison",
y = "Value",
x = ""
) +
theme_minimal()+
theme(plot.title = element_text(hjust = 0.5, size = 16))
p1 <- plot_roc(data.frame(ground_truth, probs=lr_probs2$.pred_1), "Logistic Regression")
p2 <- plot_roc(data.frame(ground_truth, probs=rf_probs), "Random Forest")
p3 <- plot_roc(data.frame(ground_truth, probs=xgb_probs), "XGB")
grid.arrange(p1, p2, p3, ncol = 3)
From the confusion matrices, we see that many individuals are predicted as stroke cases, since we chose a lower cut-off threshold from the ROC curve. Logistic regression has the smallest number of false positives.
p1 <- plot_confusion_matrix(ground_truth, lr_result2$pred_class, round(lr_result2$matcc, 2), "Logistic regression")
p2 <- plot_confusion_matrix(ground_truth, rf_result$pred_class, round(rf_result$matcc, 2), "Random forest")
p3 <- plot_confusion_matrix(ground_truth, xgb_result$pred_class, round(xgb_result$matcc, 2), "XGB")
grid.arrange(p1, p2, p3, ncol = 3)
Now let us look at the distribution of incorrect predictions for logistic regression (trained with SMOTE). We consider important variables such as age and average glucose level. Since age is a very important risk factor for stroke, there are, as expected, many false positives at older ages. There are many false positives across all values of average glucose level, but they are mostly concentrated around its two modes.
preds <- lr_result2$pred_class
dat_test_preds <- dat_test
dat_test_preds$preds <- factor(preds, levels=c(1,0))
dat_test_preds$outcome <- with(dat_test_preds, ifelse(
stroke == 1 & preds == 1, "True Positive",
ifelse(stroke == 0 & preds == 0, "True Negative",
ifelse(stroke == 0 & preds == 1, "False Positive", "False Negative"))))
p1 <- ggplot(dat_test_preds, aes(x = avg_glucose_level, fill = factor(stroke))) +
geom_density(alpha = 0.4) +
# Rug plot for incorrect predictions
geom_rug(data = subset(dat_test_preds, outcome %in% c("False Positive", "False Negative")),
aes(color = outcome),
sides = "b", alpha = 0.7) +
# Manual fill colors for stroke = 0 and 1
scale_fill_manual(values = c("0" = "#A6CEE3", "1" = "#33A02C"),
labels = c("No Stroke", "Stroke"),
name = "Stroke Status") +
# Manual colors for rug (incorrect predictions)
scale_color_manual(values = c("False Positive" = "blue", "False Negative" = "red"),
name = "Incorrect Prediction") +
labs(x = "Average glucose level", y = "Density") +
theme_minimal(base_size = 14)
p2 <- ggplot(dat_test_preds, aes(x = age, fill = factor(stroke))) +
geom_density(alpha = 0.4) +
# Rug plot for incorrect predictions
geom_rug(data = subset(dat_test_preds, outcome %in% c("False Positive", "False Negative")),
aes(color = outcome),
sides = "b", alpha = 0.7) +
# Manual fill colors for stroke = 0 and 1
scale_fill_manual(values = c("0" = "#A6CEE3", "1" = "#33A02C"),
labels = c("No Stroke", "Stroke"),
name = "Stroke Status") +
# Manual colors for rug (incorrect predictions)
scale_color_manual(values = c("False Positive" = "blue", "False Negative" = "red"),
name = "Incorrect Prediction") +
labs(x = "Age", y = "Density") +
theme_minimal(base_size = 14)
grid.arrange(p1, p2, ncol=2)
For deployment, we train logistic regression on the entire dataset, where we impute the missing values of BMI and generate synthetic examples of the rare class.
final_data <- dat
final_recipe <- recipe(stroke ~ ., data = final_data) %>%
step_impute_knn(all_predictors()) %>%
step_dummy(all_nominal_predictors()) %>%
step_smote(stroke, over_ratio = 0.5)
final_workflow <- workflow() %>%
add_model(lr_model2) %>%
add_recipe(final_recipe)
final_fit <- fit(final_workflow, data = final_data)
saveRDS(final_fit, file = "stroke_model.rds")
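For illustration, the saved model could later be reloaded and used to score a new patient, e.g., inside the Shiny app. The new_patient data frame below is hypothetical; it assumes the column names match the training data and that bmi was converted to numeric earlier in the analysis:

```r
# Sketch: reload the fitted workflow and predict the stroke probability
# for a hypothetical new patient (all values below are made up).
final_fit <- readRDS("stroke_model.rds")

new_patient <- data.frame(
  id = 1L, gender = "Female", age = 67, hypertension = 0L,
  heart_disease = 1L, ever_married = "Yes", work_type = "Private",
  Residence_type = "Urban", avg_glucose_level = 228.69,
  bmi = 36.6, smoking_status = "formerly smoked"
)

# Returns per-class probabilities; the recipe steps saved in the
# workflow (kNN imputation, dummy coding) are applied automatically.
predict(final_fit, new_patient, type = "prob")
```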
The prevalence of stroke cases is approximately 5%, making this a highly imbalanced dataset from a classification perspective. Data visualization revealed considerable overlap between the class-conditional distributions of key variables such as body mass index, average glucose level, heart disease, and age. This indicates that no single covariate is clearly discriminative for stroke classification.
Statistical analysis confirmed that age is a particularly important risk factor: the majority of stroke cases occur above the age of ~60. Other significant predictors include hypertension and average glucose level. In contrast, gender does not appear to play a substantial role, with stroke occurring at similar rates across male and female individuals.
To address the class-imbalance problem, we applied SMOTE to synthetically generate examples of the stroke class, and we used a lower-than-standard probability threshold (below 0.5) to classify strokes. Both choices increase model sensitivity at the cost of more false positives. We used a parametric method, logistic regression, as well as tree-based models capable of learning highly non-linear decision boundaries and well adapted to data with many categorical variables: random forest and gradient boosting. All models are comparable in predictive performance: they are better than a random guess according to the Matthews correlation coefficient, which is computed from the whole confusion matrix. However, logistic regression with SMOTE stood out with respect to AUC, F1-measure, Matthews correlation coefficient and specificity.
There is a clear trade-off between sensitivity and specificity: improving sensitivity inevitably results in more false positives. Adjusting the classification threshold, for example, alters the balance between these metrics. This raises an important epidemiological question: what are the consequences of generating many false positives in order to detect a rare but serious condition like stroke?